#include <iostream>
#include <cstdio>
using namespace std;
// Modulus applied to the answer.
constexpr long long P = 998244353;

// Reads two integers n and m from bpmp.in and writes (n * m - 1) mod P,
// normalized into [0, P), to bpmp.out.
//
// The original formula (n - 1) + (m - 1) * n simplifies algebraically to
// n * m - 1. That expression is symmetric in n and m, so the earlier swap
// (which put the larger of the *already reduced* values into n) had no effect
// on the result and has been dropped.
int main() {
	freopen("bpmp.in", "r", stdin);
	freopen("bpmp.out", "w", stdout);

	long long n = 0;
	long long m = 0;
	cin >> n >> m;

	// Reduce first so the product of two values below P (< 2^30) fits in
	// a long long without overflow.
	// NOTE(review): assumes n and m are non-negative; confirm input constraints.
	n %= P;
	m %= P;

	// Add P before the final reduction. When n * m is divisible by P, the bare
	// expression n * m % P - 1 is -1, and C++ '%' keeps the dividend's sign,
	// so the old code printed -1 instead of P - 1.
	const long long ans = (n * m % P - 1 + P) % P;
	cout << ans;
	return 0;
}
